{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Array Broadcasting: What is x+y?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-04T17:31:51.572156Z",
     "start_time": "2021-01-04T17:31:51.561157Z"
    },
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = np.array([[0], [1], [2]])\n",
    "y = np.array([[3, 4, 5]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-04T17:31:52.771356Z",
     "start_time": "2021-01-04T17:31:52.754353Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0]\n",
      " [1]\n",
      " [2]]\n",
      "[[3 4 5]]\n"
     ]
    }
   ],
   "source": [
    "print(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-04T17:34:05.749379Z",
     "start_time": "2021-01-04T17:34:05.735377Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[3, 4, 5],\n",
       "       [4, 5, 6],\n",
       "       [5, 6, 7]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x + y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Array Broadcasting\n",
    "\n",
    "A sensible way of doing elementwise operations on arrays of different (but compatible) shapes.\n",
    "\n",
    "$$\\pmatrix{0 & 1  & 2\\\\ 3 & 4 & 5} + \\pmatrix{1 & 2 & 3} = \\pmatrix{1 & 3  & 5\\\\ 4 & 6 & 8}$$\n",
    "\n",
    "$$\\pmatrix{0 & 1  & 2\\\\ 3 & 4 & 5} + \\pmatrix{1 \\\\ 2} = \\pmatrix{1 & 2 & 3 \\\\ 5 & 6 & 7}$$\n",
    "\n",
    "It works with plus, minus, times, exponentiation, min/max, and many more elementwise operations. Search the NumPy docs for the word \"broadcast\" to check whether a particular function supports it."
   ]
  },
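  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "As a quick check, the two sums above can be reproduced directly in NumPy. This is a minimal sketch; the names `a`, `row`, and `col` are introduced here only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# a, row, and col are illustrative names, not used elsewhere in these slides\n",
    "a = np.array([[0, 1, 2],\n",
    "              [3, 4, 5]])     # shape (2, 3)\n",
    "row = np.array([1, 2, 3])     # shape (3,): added to every row of a\n",
    "col = np.array([[1], [2]])    # shape (2, 1): added to every column of a\n",
    "\n",
    "print(a + row)\n",
    "print(a + col)"
   ]
  },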
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Shape Compatibility Rules\n",
    "\n",
    "1. If x and y have a different number of dimensions, prepend 1's to the shape of the array with fewer dimensions.\n",
    "2. Any axis of length 1 can be repeated (broadcast) to match the other array's length along that axis.\n",
    "3. All other axes must have matching lengths.\n",
    "\n",
    "Use these rules to compute whether the arrays are compatible and, if so, the broadcasted shape."
   ]
  },
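  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "The three rules can be sketched in plain Python. `broadcast_shape` below is a hypothetical helper written for this slide, not a NumPy function; it returns the broadcast shape, or raises `ValueError` when the shapes are incompatible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "def broadcast_shape(shape_x, shape_y):\n",
    "    # Hypothetical helper (not part of NumPy) that follows the three rules above.\n",
    "    ndim = max(len(shape_x), len(shape_y))\n",
    "    # Rule 1: prepend 1's to the shorter shape.\n",
    "    shape_x = (1,) * (ndim - len(shape_x)) + tuple(shape_x)\n",
    "    shape_y = (1,) * (ndim - len(shape_y)) + tuple(shape_y)\n",
    "    result = []\n",
    "    for nx, ny in zip(shape_x, shape_y):\n",
    "        if nx == 1:\n",
    "            result.append(ny)  # Rule 2: repeat the length-1 axis of x\n",
    "        elif ny == 1 or nx == ny:\n",
    "            result.append(nx)  # Rule 2 (for y) or Rule 3 (lengths match)\n",
    "        else:\n",
    "            raise ValueError(f'axis lengths {nx} and {ny} are incompatible')\n",
    "    return tuple(result)\n",
    "\n",
    "print(broadcast_shape((2, 3), (3,)))\n",
    "print(broadcast_shape((3, 1), (1, 3)))"
   ]
  },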
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Example 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "x.shape == (2, 3)\n",
    "\n",
    "y.shape == (2, 3)  # compatible\n",
    "y.shape == (2, 1)  # compatible\n",
    "y.shape == (1, 3)  # compatible\n",
    "y.shape == (3,)  # compatible\n",
    "\n",
    "# results in (2, 3) shape\n",
    "\n",
    "y.shape == (3, 2)  # NOT compatible\n",
    "y.shape == (2,)  # NOT compatible"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "# Example 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x.shape == (1000, 256, 256, 256)\n",
    "\n",
    "y.shape == (1000, 256, 256, 256)  # compatible\n",
    "y.shape == (1000, 1, 256, 256)  # compatible\n",
    "y.shape == (1000, 1, 1, 256)  # compatible\n",
    "y.shape == (1, 256, 256, 256)  # compatible\n",
    "y.shape == (1, 1, 256, 1)  # compatible\n",
    "\n",
    "# results in (1000, 256, 256, 256) shape\n",
    "\n",
    "y.shape == (1000, 256, 256)  # NOT compatible"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "# Example 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x.shape == (1, 2, 3, 5, 1, 11, 1, 17)\n",
    "y.shape ==          (1, 7, 1,  1, 17)  # compatible\n",
    "\n",
    "# results in shape (1, 2, 3, 5, 7, 11, 1, 17)"
   ]
  },
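  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "The shapes in these examples can be checked without allocating any arrays. The sketch below assumes NumPy 1.20 or later, where `np.broadcast_shapes` is available; it raises `ValueError` for incompatible shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Assumes NumPy >= 1.20, which provides np.broadcast_shapes.\n",
    "print(np.broadcast_shapes((2, 3), (3,)))\n",
    "print(np.broadcast_shapes((1000, 256, 256, 256), (1, 1, 256, 1)))\n",
    "print(np.broadcast_shapes((1, 2, 3, 5, 1, 11, 1, 17), (1, 7, 1, 1, 17)))\n",
    "\n",
    "# An incompatible pair from Example 1 raises ValueError.\n",
    "try:\n",
    "    np.broadcast_shapes((2, 3), (2,))\n",
    "except ValueError as err:\n",
    "    print('NOT compatible:', err)"
   ]
  },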
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Once shapes match, use a for-loop to understand\n",
    "\n",
    "For any axis of length 1, always use index 0, the only valid index along that axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1   2   3]\n",
      " [ 13  14  15]\n",
      " [106 107 108]]\n",
      "[[  1   2   3]\n",
      " [ 13  14  15]\n",
      " [106 107 108]]\n"
     ]
    }
   ],
   "source": [
    "x = np.array([[0, 1, 2],\n",
    "              [3, 4, 5],\n",
    "              [6, 7, 8]])\n",
    "y = np.array([1, 10, 100]).reshape(3, 1)\n",
    "\n",
    "print(x + y)\n",
    "\n",
    "# x     (3, 3)\n",
    "# y     (3, 1)\n",
    "shape = (3, 3)\n",
    "out = np.empty(shape, dtype=int)\n",
    "N0, N1 = shape\n",
    "for i in range(N0):\n",
    "    for j in range(N1):\n",
    "        # in the dimension that y only has 1 element, just use it\n",
    "        out[i, j] = x[i, j] + y[i, 0]\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "For prepended 1's, just omit the corresponding loop indices when indexing the shorter array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[  1  11 102]\n",
      "  [  4  14 105]\n",
      "  [  7  17 108]]\n",
      "\n",
      " [[ 10  20 111]\n",
      "  [ 13  23 114]\n",
      "  [ 16  26 117]]]\n",
      "[[[  1  11 102]\n",
      "  [  4  14 105]\n",
      "  [  7  17 108]]\n",
      "\n",
      " [[ 10  20 111]\n",
      "  [ 13  23 114]\n",
      "  [ 16  26 117]]]\n"
     ]
    }
   ],
   "source": [
    "x = np.array([[[0, 1, 2],\n",
    "               [3, 4, 5],\n",
    "               [6, 7, 8]],\n",
    "              [[9, 10, 11],\n",
    "               [12, 13, 14],\n",
    "               [15, 16, 17]]])  # shape (2, 3, 3)\n",
    "y = np.array([1, 10, 100])     # shape       (3,)\n",
    "\n",
    "print(x + y)\n",
    "\n",
    "# align and prepend\n",
    "# x     (2, 3, 3)\n",
    "# y     (1, 1, 3)\n",
    "shape = (2, 3, 3)\n",
    "out = np.empty(shape, dtype=int)\n",
    "N0, N1, N2 = shape\n",
    "for i in range(N0):\n",
    "    for j in range(N1):\n",
    "        for k in range(N2):\n",
    "            # leave off prepended indices of y\n",
    "            out[i, j, k] = x[i, j, k] + y[k]\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Both arrays can have broadcasted axes, not just one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3 4 5]\n",
      " [4 5 6]\n",
      " [5 6 7]]\n",
      "[[3 4 5]\n",
      " [4 5 6]\n",
      " [5 6 7]]\n"
     ]
    }
   ],
   "source": [
    "x = np.array([[0], [1], [2]])  # (3, 1)\n",
    "y = np.array([[3, 4, 5]])     # (1, 3)\n",
    "\n",
    "print(x + y)\n",
    "\n",
    "shape = (3, 3)\n",
    "out = np.empty(shape, dtype=int)\n",
    "N0, N1 = shape\n",
    "for i in range(N0):\n",
    "    for j in range(N1):\n",
    "        out[i, j] = x[i, 0] + y[0, j]\n",
    "print(out)"
   ]
  },
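  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Another way to picture this is to expand the length-1 axes explicitly with `np.broadcast_to` before adding. This is only an illustration: NumPy does not need to copy the data in order to broadcast. The names `x_full` and `y_full` are introduced here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = np.array([[0], [1], [2]])  # (3, 1)\n",
    "y = np.array([[3, 4, 5]])     # (1, 3)\n",
    "\n",
    "# Repeat each length-1 axis out to the broadcast shape (3, 3).\n",
    "# x_full and y_full are illustrative names.\n",
    "x_full = np.broadcast_to(x, (3, 3))\n",
    "y_full = np.broadcast_to(y, (3, 3))\n",
    "print(x_full)\n",
    "print(y_full)\n",
    "\n",
    "# Adding the expanded arrays gives the same result as x + y.\n",
    "print(np.array_equal(x_full + y_full, x + y))"
   ]
  },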
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Exercise for you\n",
    "\n",
    "You have N=1000 images of size WxH=32x32, where each image has C=3 channels (red, green, and blue, each with a value from 0 to 255) at every pixel location.\n",
    "\n",
    "Suppose the images are stored in an array x of shape (N, C, W, H) == (1000, 3, 32, 32).\n",
    "\n",
    "What array y would you multiply by to scale every red pixel by 2, every green pixel by 3, and every blue pixel by 4 (don't worry about overflow)?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Answer\n",
    "y = np.array([2, 3, 4]).reshape(1, 3, 1, 1)\n",
    "# or\n",
    "# y = np.array([2, 3, 4]).reshape(3, 1, 1)\n",
    "\n",
    "# Understand with for-loops\n",
    "shape = (1000, 3, 32, 32)\n",
    "x = np.ones(shape, dtype=np.uint8)  # for example\n",
    "out = np.empty(shape, dtype=np.uint8)\n",
    "N, C, W, H = shape\n",
    "for n in range(N):\n",
    "    for channel in range(C):\n",
    "        for w in range(W):\n",
    "            for h in range(H):\n",
    "                out[n, channel, w, h] = x[n, channel, w, h] * y[0, channel, 0, 0]"
   ]
  }
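  ,
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Both answers can be compared on a smaller batch with the broadcast product itself. This is a minimal sketch: the batch size of 4 is an arbitrary choice to keep it fast, and `y4` and `y3` are names introduced here for the two reshaped arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# A small batch (4 images, an arbitrary choice) of all-ones pixels.\n",
    "shape = (4, 3, 32, 32)\n",
    "x = np.ones(shape, dtype=np.uint8)\n",
    "y4 = np.array([2, 3, 4]).reshape(1, 3, 1, 1)  # explicit leading axis\n",
    "y3 = np.array([2, 3, 4]).reshape(3, 1, 1)     # leading 1 is prepended by rule 1\n",
    "\n",
    "out = x * y4\n",
    "print(out.shape)\n",
    "print(out[0, :, 0, 0])              # one pixel: red, green, blue scale factors\n",
    "print(np.array_equal(out, x * y3))  # both answers agree"
   ]
  }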
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
